//
// Created by song on 16-12-22.
//

#ifndef C0COMPILER_FUNCTION_H
#define C0COMPILER_FUNCTION_H

#include <sstream>
#include <set>
#include "IR.h"
#include "BasicBlock.h"
#include "AbstractIREvaluator.h"


class Function{
    SymbolTable* symTable;
    map<TmpVar *, SymbolTable::VarRecord *> varMap;
    set<TmpVar*> globalVars;
public:
    IRList* irList;
    Function(IRList *irList, SymbolTable *pTable, map<TmpVar *, SymbolTable::VarRecord *> &varMap)
            : irList(irList), symTable(pTable), varMap(varMap) {
        map<TmpVar *, SymbolTable::VarRecord *>::const_iterator iter=varMap.begin();
        while(iter!=varMap.end()){
            if((*iter).second->isGlobal){
                globalVars.insert((*iter).first);
            }
            iter++;
        }
    }

private:
    vector<BasicBlock*> blockList;
    map<TmpVar*, BasicBlock*> blockLabelMap;
    set<IRType> jumpOps = { JMP, JE, JNE, JG, JGE, JL, JLE };
    set<std::pair<BasicBlock*,BasicBlock*>> yesCondLink;// is a subset of blockLinks. only contains links from condition blocks to their yes result.
    set<std::pair<BasicBlock*,BasicBlock*>> blockLinks;

    bool isJump(IR *pIR) {
        return jumpOps.count(pIR->op) > 0;
    }

    void addBlock(BasicBlock*block) {
        if(blockList.empty()){
            blockList.push_back(block);
        }else{
            BasicBlock* lastBlock = blockList[blockList.size()-1];
            lastBlock->next = block;
            block->pre = lastBlock;
            blockList.push_back(block);
            if(block->label()!=NULL){
                blockLabelMap[block->label()] = block;
            }
        }
    }

    void updateBlockLinks() {
        for(int i=0; i<blockList.size(); i++){
            BasicBlock* b = blockList[i];
            IR* ir = b->lastIR();
            if(isJump(ir)){
                TmpVar* targetLabel = ir->result;
                if(blockLabelMap.count(targetLabel)>0){
                    BasicBlock* targetBlock = blockLabelMap[targetLabel];
                    b->to.insert(targetBlock);
                    targetBlock->from.insert(b);
                    yesCondLink.insert(std::make_pair(b, targetBlock));
                    blockLinks.insert(std::make_pair(b, targetBlock));
                }else{
                    Error::nextErrorDetail<<"try jump to an un-labeled block";
                    Error::internal(Error::Should_Not_Happen);
                }
                if(ir->op!=JMP && b->next!=NULL){
                    b->to.insert(b->next);
                    b->next->from.insert(b);
                    blockLinks.insert(std::make_pair(b, b->next));
                }
            }else if(ir->op!=RETURN && b->next!=NULL){
                b->to.insert(b->next);
                b->next->from.insert(b);
                blockLinks.insert(std::make_pair(b, b->next));
            }
        }
    }

    map<BasicBlock*, set<IR*>*> in;
    map<BasicBlock*, set<IR*>*> out;
    map<BasicBlock*, set<IR*>*> gen;
    map<BasicBlock*, set<TmpVar*>*> varGen;

    void reachingDefinition(){
//        AbstractIREvaluator irEvaluator(in, out, gen, blockList);
//        irEvaluator.eval();
        for(int i=0; i<blockList.size(); i++) {
            BasicBlock *b = blockList[i];
            set<TmpVar*>* varGenerated = new set<TmpVar*>();
            gen[b] = b->dataFlowGen(varGenerated);
            varGen[b] = varGenerated;
            in[b] =  new set<IR*>();
            out[b] = new set<IR*>();
        }
        updateBlockInOutSetUntilConverge();
        findAndMergeLiveRanges();
    }

    void findAndMergeLiveRanges(){
        for(int i=0; i<blockList.size(); i++) {
            BasicBlock *b = blockList[i];
//            cout <<"IN :"<< printSet(in[b]) << endl;
//            cout <<"GEN:"<< printSet(gen[b]) << endl;
//            cout <<"OUT:"<< printSet(out[b]) << endl;
            b->findVarDef(in[b]);
        }
    }

    void updateBlockInOutSetUntilConverge(){
        bool inUpdate, outUpdate;
        int loopCount=0;
        do{
            inUpdate=false;
            outUpdate=false;
            for(int i=0; i<blockList.size(); i++) { // update each block's in and out
                BasicBlock *b = blockList[i];
//                cout << b->name() << printSet(in[b]) << endl;
//                cout << b->name() << printSet(out[b]) << endl;
                if(updateBlockInSet(b)){
                    inUpdate=true;
                }
                if(updateBlockOutSet(b)){
                    outUpdate=true;
                }
//                cout << b->name() << printSet(in[b]) << endl;
//                cout << b->name() << printSet(out[b]) << endl;
//                cout << inUpdate << " " << outUpdate << endl;
//                char line[20];
//                std::cin.getline(line, 20);
            }
            loopCount++;
        }while(inUpdate || outUpdate);
//        cout << "all set fixed after "<< loopCount-1 << " loops." << endl;
    }

    bool updateBlockInSet(BasicBlock* b){
        bool isInSetUpdate=false;
        set<IR*>* inSet = in[b];
        for(set<BasicBlock*>::const_iterator iter=b->from.begin();
            iter!=b->from.end(); iter++){ // loop over the block's all [FROM] block
            set<IR*>* preOut = out[(*iter)]; // and get their OUT live ranges(lr).
            for(set<IR*>::const_iterator iter1=preOut->begin(); // put all the lr into this block's IN set.
                iter1!=preOut->end(); iter1++){
                IR* ir=*iter1;
                if(inSet->count(ir)==0){
                    inSet->insert(ir);
                    isInSetUpdate=true;
                }
            }
        }
        return isInSetUpdate;
    }

    bool updateBlockOutSet(BasicBlock* b){
        bool isOutSetUpdated=false;
        // First
        // 1. collect all var generated by block b
        // 2. put all IR generated to out[b]
        set<IR*>::const_iterator iter = gen[b]->begin();
        while(iter!=gen[b]->end()){
            IR* ir=*iter;
            if(out[b]->count(ir)==0){ // in case of invoke this function multiple times.
                isOutSetUpdated=true;
                out[b]->insert(*iter);
            }
            iter++;
        }
        // Second
        // put all IR of in[b] into out[b]
        // Except: the IR which the var is redefined. in this case, do nothing.
        set<IR*>::const_iterator j = in[b]->begin();
        while(j!=in[b]->end()){
            IR* ir = (*j);
            if(out[b]->count(ir)==0 && varGen[b]->count(ir->result)==0){ // not redefined.
                isOutSetUpdated=true;
                out[b]->insert(ir);
            }
            j++;
        }
        return isOutSetUpdated;
    }

    void printBlock(){
        set<string> pairHashSet;
        map<string, bool> conditionBlocks;

        for(int i=0; i<blockList.size(); i++){
            BasicBlock *b = blockList[i];
            string id = blockID(b);
            char type=id[0];
            switch(type){
                case 'C':
                    cout << id << "=>condition: " << b->name() << endl;
                    conditionBlocks[id]=true;
                    break;
                case 'B':
                    cout << id << "=>operation: " << b->name() << endl;
                    break;
                case 'E':
                    cout << id << "=>operation: " << b->name() <<"(return)"<< endl;
                    break;
                default:
                    Error::nextErrorDetail<<"block type is not S|C|B|E~";
                    Error::internal(Error::Should_Not_Happen);
            }
        }
        cout << "end=>end" << endl;

        set<std::pair<BasicBlock*, BasicBlock*>>::const_iterator iter=blockLinks.begin();
        while(iter!=blockLinks.end()){
            std::pair<BasicBlock*, BasicBlock*> pairLink = *iter;
            BasicBlock* from=pairLink.first;
            BasicBlock* to = pairLink.second;
            if(conditionBlocks[blockID(from)]) {
                if(yesCondLink.count(pairLink)>0){
                    cout << blockID(from) << "(yes)->" << blockID(to) << endl;
                }else{
                    cout << blockID(from) << "(no)->" << blockID(to) << endl;
                }
            }else{
                cout << blockID(from) << "->" << blockID(to) << endl;
            }
            iter++;
        }

        for(int i=0; i<blockList.size(); i++){
            string bID = blockID(blockList[i]);
            if(bID[0]=='E'){
                cout << bID << "->end" << endl;
            }
        }
        cout << endl;
    }

    string blockID(BasicBlock* b){
        stringstream r;
        r << '_' << std::hex << b->id();
        string id = r.str();//.substr(5,9);
        long outBlocks = b->to.size();
        if(outBlocks==2){
            return 'C'+id;
        }else if(outBlocks==1){
            return 'B'+id;
        }else if(outBlocks==0){
            return 'E'+id;
        }else{
            Error::nextErrorDetail<<"jump to more than 2 place? impossible!";
            Error::internal(Error::Should_Not_Happen);
        }
    }

    string printSet(set<IR *> * &s)const{
        std::stringstream r;
        set<IR*>::const_iterator iter = s->begin();
        while(iter!=s->end()){
            IR* ir = *iter;
//            string ir = (*iter)->print();
//            r << Utils::removeSpaces(ir) << "  |||  ";
            r << " | " << ir->result->name();
            iter++;
        }
        return r.str();
    }

public:
    void cutToBlock(IR* first, IR* last){
        Function* function;
        IRList* v = irList->getIRListBetween(first, last);
//        cout << v->toString() << endl;
//        return ;
        vector<IR*> irs = v->clist;
        vector<int> marker; // after the IR, the block end.
        for(int i=0; i<irs.size(); i++) {
            IR *ir = irs[i];
            bool begin = (ir->label!=NULL && ir->op!=FUNBEGIN);
            bool end = (isJump(ir) || ir->op == CALL || ir->op == FUNEND);
            bool funEnd = (ir->op == FUNEND);
            if (begin) {
                if(marker.empty() || marker[marker.size()-1]!=i) {
                    marker.push_back(i);
                }
            }
            if (end) {
                if(marker.empty() || marker[marker.size()-1]!=i+1) {
                    marker.push_back(i+1);
                }
            }
        }

        int j=0;
        BasicBlock * currentBlock = new BasicBlock(globalVars);
        for(int i=0; i< irs.size(); i++){
            currentBlock->appendIR(irs[i]);
            if(i==(marker[j]-1)){
                addBlock(currentBlock);
                if(j<(marker.size()-1)){
                    currentBlock = new BasicBlock(globalVars);
                    j++;
                }else{
                    break;
                }
            }
        }

        updateBlockLinks();
        reachingDefinition();
        if(Config::printControlFlowChart) printBlock();
    }

    string toString(){
        stringstream r;
        r << "################################################" << endl;
        for(int i=0; i<blockList.size(); i++){
            if(i>0) r << "------------------------------------------" << endl;
            r << blockList[i]->toString();
        }
        r << "################################################" << endl;
        return r.str();
    }


};


#endif //C0COMPILER_FUNCTION_H
